library(redist)
library(tidyverse)
library(maptools)
library(doMC)
library(igraph)
library(parallel)
source("../random_submap.R")

## ggplot2 theme used for figures.
##
## Starts from theme_classic() and overrides it as follows:
## - removes the panel background and grid lines;
## - draws a black axis line and a black, unfilled panel border;
## - blanks the facet strip background;
## - puts the legend at the bottom with no title;
## - centres the plot title.
## NOTE(review): ggplot2 >= 3.4.0 deprecates `size` in element_rect() in
## favour of `linewidth`; kept as `size` for older ggplot2 installs.
ben_theme <- function(){
    overrides <- theme(
        panel.background = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        panel.border = element_rect(colour = "black", fill = NA, size = 1),
        strip.background = element_blank(),
        legend.position = "bottom",
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5)
    )
    theme_classic() + overrides
}
## Importance-resample a (thinned) redist.mcmc() result.
##
## Keeps the draws taken at `beta` (x$beta_sequence == beta) whose population
## deviation is within `pop` (x$distance_parity <= pop). It then resamples
## them with replacement using weights proportional to 1 / exp(beta * psi),
## where psi = x$constraint_pop. This presumably undoes the population-
## constraint tilt the chain was run with -- TODO confirm against redist.
##
## Args:
##   x:    list with `partitions` (matrix, one column per draw) and per-draw
##         vectors `distance_parity`, `beta_sequence`, `constraint_pop`.
##   beta: constraint strength to select and reweight.
##   pop:  maximum population deviation to keep.
## Returns: matrix of resampled partitions with one column per kept draw. It
##   is always a matrix (even with one column). If no draw qualifies, the
##   result has zero columns and a warning is issued.
ipw <- function(x, beta, pop){
    inds <- intersect(which(x$distance_parity <= pop),
                      which(x$beta_sequence == beta))
    if(length(inds) == 0){
        warning("ipw(): no draws with beta == ", beta,
                " and distance_parity <= ", pop, call. = FALSE)
        return(x$partitions[, integer(0), drop = FALSE])
    }
    psi <- x$constraint_pop[inds]
    ## log(1 / exp(beta * psi)) == -beta * psi. Shift by the max before
    ## exponentiating so large |beta * psi| cannot overflow to Inf or
    ## underflow to 0. sample.int() normalises `prob`, so the common factor
    ## does not change the resampling weights.
    log_w <- -beta * psi
    w <- exp(log_w - max(log_w))
    ## Draw positions with sample.int(): sample(inds, ...) would treat a
    ## single index k as 1:k and sample from the wrong set.
    picks <- sample.int(length(inds), length(inds), replace = TRUE, prob = w)
    ## drop = FALSE keeps a one-draw result a matrix for redist.segcalc().
    x$partitions[, inds[picks], drop = FALSE]
}

## Simulation grid: one row per SLURM array task. There are 200 rows, each a
## 100-precinct, 2-district setting; the `nsims` column is not read by this
## script.
params <- expand.grid(nprecs = 100, ndists = 2, nsims = 1:200)

## Iterations per MCMC chain (1e7)
appx_sims <- 1e7

## Load the full precinct map.
## NOTE(review): maptools::readShapePoly is deprecated (maptools is retired);
## consider sf::st_read if this script is revived.
nh <- readShapePoly("../../data/nh/nh_final.shp")

## Output directory for this experiment; silent if it already exists
dir.create("../../data/qq_test_sample", showWarnings = FALSE)

## Get info on job: the SLURM array task id selects one row of `params`.
## Fail loudly when it is missing or out of range, instead of continuing
## with NA parameters.
i <- as.numeric(Sys.getenv("SLURM_ARRAY_TASK_ID"))
if(is.na(i) || i != round(i) || i < 1 || i > nrow(params)){
    stop("SLURM_ARRAY_TASK_ID must be an integer between 1 and ",
         nrow(params), " (got '", Sys.getenv("SLURM_ARRAY_TASK_ID"), "')",
         call. = FALSE)
}

## -----------------
## Sim in each array
## -----------------

## Get parameters for this task
nprecs <- params$nprecs[i]
ndists <- params$ndists[i]

## Look for a submap with a small enough frontier. Sample submaps until
## calc_frontier_size.py reports a maximum frontier <= max_frontier.
## There is no cap on the number of attempts.
max_frontier <- 12
map_file <- paste0("../../data/qq_test_sample/adjlist_map_", i, ".dat")
ordered_file <- paste0("../../data/qq_test_sample/adjlist_map_ordered_", i, ".dat")
repeat{

    ## Sample a random nprecs-unit submap and add 1 to POP100 and
    ## PRES_RVOTE (presumably to avoid zero counts -- TODO confirm)
    nh_sub <- random_submap(nh, nprecs)
    nh_sub@data$POP100 <- nh_sub@data$POP100 + 1
    nh_sub@data$PRES_RVOTE <- nh_sub@data$PRES_RVOTE + 1

    ## Rook-contiguity adjacency list (queen = FALSE). The spdep:: prefix
    ## makes the dependency explicit rather than relying on another
    ## package or random_submap.R having attached spdep.
    adjlist <- spdep::poly2nb(nh_sub, queen = FALSE)

    ## Write each undirected edge once as a 1-indexed (k, j) row with
    ## k < j. Filtering on j > k also drops spdep's 0 "no neighbour" code.
    ## The matrix is built in one rbind instead of growing row by row.
    adjlist_map <- do.call(rbind, lapply(seq_along(adjlist), function(k){
        nbrs <- adjlist[[k]]
        nbrs <- nbrs[nbrs > k]
        if(length(nbrs) == 0){
            return(NULL)
        }
        cbind(k, nbrs)
    }))
    write_delim(data.frame(adjlist_map), path = map_file, col_names = FALSE)

    ## Order edges; stop rather than read a stale/partial ordered file
    ndscut_status <- system(paste0("python ../ndscut.py <", map_file, " >", ordered_file))
    if(ndscut_status != 0){
        stop("ndscut.py failed with exit status ", ndscut_status, call. = FALSE)
    }

    ## Calculate frontier: the number is parsed from the first output line
    outp <- system(paste0("python ../../enumpart_private/frontier_size/calc_frontier_size.py < ", ordered_file), intern = TRUE)
    maxf <- parse_number(outp[1])
    if(is.na(maxf)){
        stop("Could not parse a frontier size from calc_frontier_size.py output",
             call. = FALSE)
    }

    if(maxf <= max_frontier){
        break
    }

}

## Run enumpart on the ordered edge list with `-k ndists -sample 1000000`;
## its stdout is written to the solutions file read below.
system(paste0("../../enumpart_private/enumpart ../../data/qq_test_sample/adjlist_map_ordered_", i, ".dat -k ", ndists, " -sample 1000000 >/n/imai_lab/bfifield/enumeration/qq_test_sample/adjlist_map_sols_", i, ".dat"))

## Load solutions. Each line is a space-separated list of row indices into
## the ordered edge list (presumably the edges kept inside districts --
## TODO confirm against enumpart).
sols <- readLines(paste0("/n/imai_lab/bfifield/enumeration/qq_test_sample/adjlist_map_sols_", i, ".dat"))
if(length(sols) == 0){
    stop("enumpart produced no solutions for task ", i, call. = FALSE)
}
el <- do.call(rbind, strsplit(readLines(paste0("../../data/qq_test_sample/adjlist_map_ordered_", i, ".dat")), split = " "))
storage.mode(el) <- "numeric"
n_units <- length(adjlist)

## Convert each solution to a district assignment: connected components of
## the kept edges, one per district. The graph is built on all n_units
## vertices, so units with no kept edge still appear as singleton
## components. This replaces the old "pad trailing isolates" workaround.
sols_out <- mclapply(seq_along(sols), function(x){
    sol_split <- as.numeric(strsplit(sols[x], split = " ")[[1]])
    ## drop = FALSE: a one-edge solution must stay a 1 x 2 matrix
    el_sub <- el[sol_split, , drop = FALSE]
    g <- make_graph(as.vector(t(el_sub)), n = n_units, directed = FALSE)
    components(g)$membership
}, mc.cores = detectCores())
## mclapply() returns "try-error" objects for failed jobs; do not cbind them
failed <- vapply(sols_out, inherits, logical(1), what = "try-error")
if(any(failed)){
    stop(sum(failed), " enumpart solution(s) could not be decoded", call. = FALSE)
}
sols_out <- do.call(cbind, sols_out)

## Target segregation distribution and each plan's maximum population
## deviation from parity
target <- redist.segcalc(sols_out, grouppop = nh_sub@data$PRES_RVOTE, 
                         fullpop = nh_sub@data$POP100)
popdist <- apply(sols_out, 2, function(x){
    distpop <- tapply(nh_sub@data$POP100, x, sum)
    parpop <- sum(distpop) / length(distpop)
    max(abs(distpop / parpop - 1))
})

## --------------------
## Run algorithm - MCMC
## --------------------

## Keep every thin_amt-th MCMC draw
thin_amt <- 1000

## Thin a redist.mcmc() result, keeping draws 1, 1 + thin, 1 + 2 * thin, ...
##
## The number of draws is ncol(x$partitions). Components are handled by shape:
## - matrices with one column per draw (e.g. partitions) are subset by
##   column and stay matrices even when a single draw is kept;
## - vectors with one entry per draw are subset by element;
## - anything else (scalars, NULL, other shapes) is passed through unchanged.
##
## Returns a list with the same names as `x` and class "redist".
thin_chain <- function(x, thin = thin_amt){
    ndraws <- ncol(x$partitions)
    inds <- if(ndraws > 0) seq(1, ndraws, by = thin) else integer(0)
    x_new <- lapply(x, function(el){
        if(is.matrix(el) && ncol(el) == ndraws){
            el[, inds, drop = FALSE]
        }else if(!is.matrix(el) && is.atomic(el) && length(el) == ndraws){
            el[inds]
        }else{
            el
        }
    })
    class(x_new) <- "redist"
    x_new
}

## For each parity setting, run a redist.mcmc() chain of appx_sims draws and
## thin it with thin_chain() immediately. This way the full, unthinned
## partition matrix is not held in `mcmc_out` while the next chain runs.
## The constrained chains are then resampled with ipw() to draws within the
## parity threshold, using the same beta the chain was run with.

## No parity
mcmc_out <- thin_chain(redist.mcmc(nh_sub, nh_sub@data$POP100, nsims = appx_sims,
                                   ndists = ndists,
                                   maxiterrsg = 500000))
mcmc_seg_full <- redist.segcalc(mcmc_out,
                                grouppop = nh_sub@data$PRES_RVOTE, 
                                fullpop = nh_sub@data$POP100)

## 5% parity
beta_05 <- -25
mcmc_out <- thin_chain(redist.mcmc(nh_sub, nh_sub@data$POP100, nsims = appx_sims,
                                   ndists = ndists, constraint = "population",
                                   beta = beta_05,
                                   maxiterrsg = 500000))
mcmc_out <- ipw(mcmc_out, beta = beta_05, pop = .05)
mcmc_seg_05 <- redist.segcalc(mcmc_out, grouppop = nh_sub@data$PRES_RVOTE, 
                               fullpop = nh_sub@data$POP100)

## 1% parity
beta_01 <- -50
mcmc_out <- thin_chain(redist.mcmc(nh_sub, nh_sub@data$POP100, nsims = appx_sims,
                                   ndists = ndists, constraint = "population",
                                   beta = beta_01,
                                   maxiterrsg = 500000))
mcmc_out <- ipw(mcmc_out, beta = beta_01, pop = .01)
mcmc_seg_01 <- redist.segcalc(mcmc_out, grouppop = nh_sub@data$PRES_RVOTE, 
                               fullpop = nh_sub@data$POP100)

## -------------------
## Run algorithm - RSG
## -------------------

## redist.rsg() is given a 0-indexed copy of the adjacency list.
## NOTE(review): spdep codes "no neighbours" as 0, which becomes -1 here.
adjlist_zind <- lapply(adjlist, function(y){y-1})
## Presumably chosen so RSG and the thinned MCMC chains have equal sample
## sizes (appx_sims / thin_amt draws each).
nsims_rsg <- appx_sims / thin_amt
## Report progress every tenth of the run. The previous appx_sims / 10
## exceeded nsims_rsg, so nothing was ever printed.
rsg_report_every <- max(1, floor(nsims_rsg / 10))

## Draw nsims_rsg plans in parallel with redist.rsg() at population
## threshold `thresh` and return their redist.segcalc() segregation indices.
## `label` names the run in progress output; `...` is forwarded to
## redist.rsg() (e.g. maxiter). Stops if any parallel draw failed.
run_rsg_seg <- function(thresh, label, ...){
    draws <- mclapply(seq_len(nsims_rsg), function(x){
        if(x %% rsg_report_every == 0){
            print(paste0("Done with ", x, " ", label, " iterations at ", Sys.time(), "."))
        }
        redist.rsg(
            adjlist_zind,
            nh_sub@data$POP100, ndists = ndists, thresh = thresh,
            ..., verbose = FALSE
        )$district_membership
    }, mc.cores = detectCores())
    ## mclapply() returns "try-error" objects for failed jobs; cbind would
    ## silently turn the whole result into a character matrix
    failed <- vapply(draws, inherits, logical(1), what = "try-error")
    if(any(failed)){
        stop(sum(failed), " RSG draw(s) failed (", label, " run)", call. = FALSE)
    }
    redist.segcalc(do.call(cbind, draws), grouppop = nh_sub@data$PRES_RVOTE,
                   fullpop = nh_sub@data$POP100)
}

## No parity (thresh = 100 is effectively unconstrained)
rsg_seg_full <- run_rsg_seg(thresh = 100, label = "no parity")

## 5% parity
rsg_seg_05 <- run_rsg_seg(thresh = .05, label = "5%", maxiter = 500000)

## 1% parity
rsg_seg_01 <- run_rsg_seg(thresh = .01, label = "1%", maxiter = 500000)

## -------------
## Create output
## -------------

## One two-sample Kolmogorov-Smirnov test per (algorithm, parity) cell. Each
## compares the sampler's segregation indices with the enumerated target,
## restricted to plans within the same parity threshold. Each test is run
## once; the statistic and p-value are both read from that result.
## Row order (MCMC then RSG; none, 5%, 1%) matches `alg` and `parity` below.
ks_results <- list(
    ks.test(target, mcmc_seg_full),
    ks.test(target[popdist <= .05], mcmc_seg_05),
    ks.test(target[popdist <= .01], mcmc_seg_01),
    ks.test(target, rsg_seg_full),
    ks.test(target[popdist <= .05], rsg_seg_05),
    ks.test(target[popdist <= .01], rsg_seg_01)
)
out <- data.frame(
    ks_stat = vapply(ks_results, function(res){ unname(res$statistic) }, numeric(1)),
    ks_pval = vapply(ks_results, function(res){ res$p.value }, numeric(1))
)
out$nprecs <- nprecs
out$ndists <- paste0(ndists, " Districts")
out$alg <- c(rep("MCMC", 3), rep("RSG", 3))

out$parity <- rep(c("No Equal Population Constraint",
                    "5% Equal Population Constraint",
                    "1% Equal Population Constraint"), 2)

write_csv(out, path = paste0("../../data/qq_test_sample/ks_test_", i, ".csv"))
